{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MutableChannelUnit"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A MutableChannelUnit is the basic unit of channel pruning. It records all channels that depend on each other.\n",
    "This tutorial covers:\n",
    "1. The data structure of MutableChannelUnit.\n",
    "2. How to prune the model with a MutableChannelUnit.\n",
    "3. How to get MutableChannelUnits.\n",
    "4. How to develop a new MutableChannelUnit for a new pruning algorithm.\n",
    "<p align=\"center\"><img src=\"../../../../../docs/en/imgs/pruning/unit.png\" alt=\"MutableChannelUnit\" width=\"800\"></p>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The Data Structure of MutableChannelUnit"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, let's parse a model and get several MutableChannelUnits."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define a model\n",
    "from mmengine.model import BaseModel\n",
    "from torch import nn\n",
    "from collections import OrderedDict\n",
    "\n",
    "class MyModel(nn.Module):\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.net = nn.Sequential(\n",
    "            OrderedDict([('conv0', nn.Conv2d(3, 8, 3, 1, 1)),\n",
    "                         ('relu', nn.ReLU()),\n",
    "                         ('conv1', nn.Conv2d(8, 16, 3, 1, 1))]))\n",
    "        self.pool = nn.AdaptiveAvgPool2d(1)\n",
    "        self.head = nn.Linear(16, 1000)\n",
    "\n",
    "    def forward(self, x):\n",
    "        feature = self.net(x)\n",
    "        pool = self.pool(feature).flatten(1)\n",
    "        return self.head(pool)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# There are multiple types of MutableChannelUnits. Here, we take SequentialMutableChannelUnit as an example.\n",
    "from mmrazor.models.mutables.mutable_channel.units import SequentialMutableChannelUnit\n",
    "from mmrazor.structures.graph import ModuleGraph\n",
    "from typing import List\n",
    "\n",
    "model = MyModel()\n",
    "units: List[\n",
    "    SequentialMutableChannelUnit] = SequentialMutableChannelUnit.init_from_channel_analyzer(model)  # type: ignore\n",
    "print(\n",
    "    f'This model has {len(units)} MutableChannelUnit(SequentialMutableChannelUnit).'\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "unit1 = units[1]\n",
    "print(unit1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As shown above, each MutableChannelUnit has four important attributes:\n",
    "1. name: str\n",
    "2. output_related: ModuleList\n",
    "3. input_related: ModuleList\n",
    "4. mutable_channel: BaseMutableChannel\n",
    "\n",
    "\"name\" is the identifier of the MutableChannelUnit. It is usually generated automatically.\n",
    "\n",
    "\"output_related\" and \"input_related\" are two ModuleLists that store all Channels with a channel dependency.\n",
    "\"output_related\" holds output channels, and \"input_related\" holds input channels.\n",
    "All of these channels depend on each other, so they share the same \"mutable_channel\".\n",
    "\n",
    "\"mutable_channel\" is a BaseMutableChannel that controls the channel mask of the modules. It is registered to the modules whose channels are stored in \"output_related\" and \"input_related\"."
   ]
  },
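  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The relationship between these four attributes can be sketched with a toy data structure. This is a simplified illustration in plain Python, not the mmrazor implementation: `ToyChannel` and `ToyUnit` are hypothetical names, and the channel mask is assumed to be a plain list of booleans.\n",
    "\n",
    "```python\n",
    "from dataclasses import dataclass, field\n",
    "from typing import List\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class ToyChannel:\n",
    "    # Hypothetical stand-in for one Channel entry of a unit.\n",
    "    module_name: str\n",
    "    is_output: bool  # True: output channels, False: input channels\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class ToyUnit:\n",
    "    # Groups channels that depend on each other and share one mask.\n",
    "    name: str\n",
    "    num_channels: int\n",
    "    output_related: List[ToyChannel] = field(default_factory=list)\n",
    "    input_related: List[ToyChannel] = field(default_factory=list)\n",
    "    mask: List[bool] = field(default_factory=list)\n",
    "\n",
    "    def __post_init__(self):\n",
    "        # Start with every channel kept.\n",
    "        if not self.mask:\n",
    "            self.mask = [True] * self.num_channels\n",
    "\n",
    "    def modules(self) -> List[str]:\n",
    "        # Every module listed here is governed by the same mask.\n",
    "        channels = self.output_related + self.input_related\n",
    "        return [channel.module_name for channel in channels]\n",
    "\n",
    "\n",
    "# conv0 produces the 8 channels that conv1 consumes, so they form one unit.\n",
    "toy_unit = ToyUnit(\n",
    "    name='toy_unit_conv0',\n",
    "    num_channels=8,\n",
    "    output_related=[ToyChannel('net.conv0', is_output=True)],\n",
    "    input_related=[ToyChannel('net.conv1', is_output=False)])\n",
    "print(toy_unit.modules())\n",
    "print(sum(toy_unit.mask))\n",
    "```\n",
    "\n",
    "Because both lists point at the same mask, switching off a channel affects the output of `net.conv0` and the input of `net.conv1` together."
   ]
  },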
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## How to Prune the Model with a MutableChannelUnit"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There are three steps to prune the model using a MutableChannelUnit:\n",
    "1. Replace the modules whose channels are stored in \"output_related\" and \"input_related\" with dynamic ops, which can handle a variable number of channels.\n",
    "2. Register the \"mutable_channel\" to the replaced dynamic ops.\n",
    "3. Change the choice of the \"mutable_channel\".\n",
    "\n",
    "For simplicity, steps 1 and 2 are performed by a single method, \"prepare_for_pruning\"."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run \"prepare_for_pruning\" once before pruning. It performs steps 1 and 2 above.\n",
    "unit1.prepare_for_pruning(model)\n",
    "print(f'The current choice of unit1 is {unit1.current_choice}.')\n",
    "print(model.net.conv0)\n",
    "print(model.net.conv1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We prune the model by changing the current_choice of the MutableChannelUnits."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sampled_choice = unit1.sample_choice()\n",
    "print(f'We get a sampled choice {sampled_choice}.')\n",
    "unit1.current_choice = sampled_choice\n",
    "print(model.net.conv0)\n",
    "print(model.net.conv1)"
   ]
  },
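  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the effect of a choice concrete, the sketch below computes the weight shapes of `net.conv0` and `net.conv1` after a choice is applied. It is a plain-Python illustration under two assumptions that are not taken from the cells above: the choice is an integer number of channels, and the mask keeps the first `choice` channels. `sequential_mask` and `pruned_conv_shapes` are hypothetical helpers, not mmrazor functions.\n",
    "\n",
    "```python\n",
    "from typing import List, Tuple\n",
    "\n",
    "# Conv2d weights are laid out as (out_channels, in_channels, kH, kW).\n",
    "Shape = Tuple[int, int, int, int]\n",
    "\n",
    "\n",
    "def sequential_mask(num_channels: int, choice: int) -> List[bool]:\n",
    "    # Assumption: keep the first `choice` channels and drop the rest.\n",
    "    if not 0 < choice <= num_channels:\n",
    "        raise ValueError('choice must be in [1, num_channels]')\n",
    "    return [True] * choice + [False] * (num_channels - choice)\n",
    "\n",
    "\n",
    "def pruned_conv_shapes(choice: int) -> Tuple[Shape, Shape]:\n",
    "    # The unit between conv0 and conv1 has 8 channels in MyModel.\n",
    "    kept = sum(sequential_mask(8, choice))\n",
    "    conv0_weight = (kept, 3, 3, 3)   # net.conv0 keeps `kept` outputs\n",
    "    conv1_weight = (16, kept, 3, 3)  # net.conv1 keeps `kept` inputs\n",
    "    return conv0_weight, conv1_weight\n",
    "\n",
    "\n",
    "print(sequential_mask(8, 5))\n",
    "print(pruned_conv_shapes(5))\n",
    "```\n",
    "\n",
    "One mask trims both layers, which is why the output channels of `net.conv0` and the input channels of `net.conv1` always stay consistent."
   ]
  },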
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that different types of MutableChannelUnit may use different types of choices. Please read the documentation of each unit for more details."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## How to Get MutableChannelUnits"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There are three ways to get MutableChannelUnits.\n",
    "1. Using a tracer.\n",
    "   This approach first converts the model to a graph and then converts the graph to MutableChannelUnits. It automatically returns all available MutableChannelUnits.\n",
    "2. Using a config.\n",
    "   This approach initializes a MutableChannelUnit from a config.\n",
    "3. Using a predefined model.\n",
    "   This approach parses a predefined model that is built with dynamic ops. It returns all available MutableChannelUnits.\n",
    "\n",
    "All three approaches are documented in the README of ChannelMutator."
   ]
  },
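  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The idea behind the tracer approach can be sketched on a purely sequential model: the output channels of each layer are grouped with the input channels of the layer that consumes them. This toy version in plain Python is only an illustration. `group_sequential_channels` is a hypothetical helper, layers that do not change the channel number (ReLU, pooling) are left out, and a real tracer has to handle more general graphs, so the units it returns may differ from the ones listed here.\n",
    "\n",
    "```python\n",
    "from typing import Dict, List, Tuple\n",
    "\n",
    "# (name, in_channels, out_channels) for a purely sequential model.\n",
    "Layer = Tuple[str, int, int]\n",
    "\n",
    "\n",
    "def group_sequential_channels(layers: List[Layer]) -> List[Dict]:\n",
    "    units = []\n",
    "    for index, (name, _, out_channels) in enumerate(layers):\n",
    "        consumers = []\n",
    "        if index + 1 < len(layers):\n",
    "            next_name, next_in, _ = layers[index + 1]\n",
    "            # A layer's outputs feed the next layer's inputs.\n",
    "            if next_in != out_channels:\n",
    "                raise ValueError('channel numbers must match')\n",
    "            consumers.append(next_name)\n",
    "        units.append({\n",
    "            'num_channels': out_channels,\n",
    "            'output_related': [name],\n",
    "            'input_related': consumers,\n",
    "        })\n",
    "    return units\n",
    "\n",
    "\n",
    "# The channel-changing layers of MyModel.\n",
    "layers = [('net.conv0', 3, 8), ('net.conv1', 8, 16), ('head', 16, 1000)]\n",
    "toy_units = group_sequential_channels(layers)\n",
    "for toy_unit_cfg in toy_units:\n",
    "    print(toy_unit_cfg)\n",
    "```\n",
    "\n",
    "Each resulting dictionary mirrors the `output_related` and `input_related` lists described earlier, which is also the information a config has to provide by hand."
   ]
  },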
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1. using tracer\n",
    "def get_mutable_channel_units_using_tracer(model):\n",
    "    units = SequentialMutableChannelUnit.init_from_channel_analyzer(model)\n",
    "    return units\n",
    "\n",
    "\n",
    "model = MyModel()\n",
    "units = get_mutable_channel_units_using_tracer(model)\n",
    "print(f'The model has {len(units)} MutableChannelUnits.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2. using config\n",
    "config = {\n",
    "    'init_args': {\n",
    "        'num_channels': 8,\n",
    "    },\n",
    "    'channels': {\n",
    "        'input_related': [{\n",
    "            'name': 'net.conv1',\n",
    "        }],\n",
    "        'output_related': [{\n",
    "            'name': 'net.conv0',\n",
    "        }]\n",
    "    },\n",
    "    'choice': 8\n",
    "}\n",
    "unit = SequentialMutableChannelUnit.init_from_cfg(model, config)\n",
    "print(unit)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 3. using predefined model\n",
    "\n",
    "from mmrazor.models.architectures.dynamic_ops import DynamicConv2d, DynamicLinear\n",
    "from mmrazor.models.mutables import MutableChannelUnit, MutableChannelContainer, SquentialMutableChannel\n",
    "from collections import OrderedDict\n",
    "\n",
    "class MyDynamicModel(BaseModel):\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__(None, None)\n",
    "        self.net = nn.Sequential(\n",
    "            OrderedDict([('conv0', DynamicConv2d(3, 8, 3, 1, 1)),\n",
    "                         ('relu', nn.ReLU()),\n",
    "                         ('conv1', DynamicConv2d(8, 16, 3, 1, 1))]))\n",
    "        self.pool = nn.AdaptiveAvgPool2d(1)\n",
    "        self.head = DynamicLinear(16, 1000)\n",
    "\n",
    "        # register MutableChannelContainer\n",
    "        MutableChannelUnit._register_channel_container(\n",
    "            self, MutableChannelContainer)\n",
    "        self._register_mutables()\n",
    "\n",
    "    def forward(self, x):\n",
    "        feature = self.net(x)\n",
    "        pool = self.pool(feature).flatten(1)\n",
    "        return self.head(pool)\n",
    "\n",
    "    def _register_mutables(self):\n",
    "        mutable1 = SquentialMutableChannel(8)\n",
    "        mutable2 = SquentialMutableChannel(16)\n",
    "        MutableChannelContainer.register_mutable_channel_to_module(\n",
    "            self.net.conv0, mutable1, is_to_output_channel=True)\n",
    "        MutableChannelContainer.register_mutable_channel_to_module(\n",
    "            self.net.conv1, mutable1, is_to_output_channel=False)\n",
    "\n",
    "        MutableChannelContainer.register_mutable_channel_to_module(\n",
    "            self.net.conv1, mutable2, is_to_output_channel=True)\n",
    "        MutableChannelContainer.register_mutable_channel_to_module(\n",
    "            self.head, mutable2, is_to_output_channel=False)\n",
    "\n",
    "\n",
    "model = MyDynamicModel()\n",
    "units = SequentialMutableChannelUnit.init_from_predefined_model(model)\n",
    "print(f'The model has {len(units)} MutableChannelUnits.')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.13 ('lab2max')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "e31a827d0913016ad78e01c7b97f787f4b9e53102dd62d238e8548bcd97ff875"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
